{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# High-level TF Example - tf.estimator.Estimator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a much more concise example of using TensorFlow, built on the high-level `tf.estimator.Estimator` API, and is hopefully the way forward. The multi-GPU example will probably build on this one, since the Estimator provides a very nice wrapper and good TensorBoard support.\n",
    "\n",
    "See example: https://github.com/BobLiu20/Classification_Nets/blob/master/tensorflow/training/train_estimator.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "from common.params import *\n",
    "from common.utils import *\n",
    "slim = tf.contrib.slim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose only GPU 0 to CUDA so the notebook runs on a single device\n",
    "os.environ.update({\"CUDA_VISIBLE_DEVICES\": \"0\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.2 |Anaconda custom (64-bit)| (default, Jul  2 2016, 17:53:06) \n",
      "[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]\n",
      "Numpy:  1.14.1\n",
      "Tensorflow:  1.4.0\n",
      "GPU:  ['Tesla P100-PCIE-16GB', 'Tesla P100-PCIE-16GB']\n",
      "CUDA Version 8.0.61\n",
      "CuDNN Version  6.0.21\n"
     ]
    }
   ],
   "source": [
    "def print_environment():\n",
    "    \"\"\"Print the platform, library, GPU, CUDA and cuDNN versions for this run.\"\"\"\n",
    "    # Interpreter and library versions\n",
    "    print(\"OS: \", sys.platform)\n",
    "    print(\"Python: \", sys.version)\n",
    "    print(\"Numpy: \", np.__version__)\n",
    "    print(\"Tensorflow: \", tf.__version__)\n",
    "    # GPU details come from the helpers imported from common.utils\n",
    "    print(\"GPU: \", get_gpu_name())\n",
    "    print(get_cuda_version())\n",
    "    print(\"CuDNN Version \", get_cudnn_version())\n",
    "\n",
    "\n",
    "print_environment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _conv_block(inputs, filters, training, drop_rate=0.25):\n",
    "    \"\"\"Apply conv(ReLU) -> conv -> 2x2 max-pool -> ReLU -> dropout in channels-first layout.\n",
    "\n",
    "    Args:\n",
    "        inputs: input tensor for the first convolution.\n",
    "        filters: number of output filters used by both 3x3 convolutions.\n",
    "        training: flag forwarded to dropout; dropout is only applied when it is True.\n",
    "        drop_rate: dropout rate applied to the pooled activations.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the dropout layer.\n",
    "    \"\"\"\n",
    "    conv_a = tf.layers.conv2d(inputs, activation=tf.nn.relu, filters=filters, kernel_size=(3, 3),\n",
    "                              padding='same', data_format='channels_first')\n",
    "    conv_b = tf.layers.conv2d(conv_a, filters=filters, kernel_size=(3, 3),\n",
    "                              padding='same', data_format='channels_first')\n",
    "    pooled = tf.layers.max_pooling2d(conv_b, pool_size=(2, 2), strides=(2, 2),\n",
    "                                     padding='valid', data_format='channels_first')\n",
    "    # The second convolution has no activation of its own; ReLU is applied after pooling\n",
    "    activated = tf.nn.relu(pooled)\n",
    "    # Tensorflow requires a flag for training in dropout\n",
    "    return tf.layers.dropout(activated, drop_rate, training=training)\n",
    "\n",
    "\n",
    "def create_symbol(X, training, n_classes):\n",
    "    \"\"\"Build the CNN graph and return the unscaled class logits.\n",
    "\n",
    "    Args:\n",
    "        X: dict holding the input batch under the key 'features'; it is fed to\n",
    "            channels-first convolutions.\n",
    "        training: True to enable dropout, False to disable it.\n",
    "        n_classes: number of units in the final 'output' dense layer.\n",
    "\n",
    "    Returns:\n",
    "        Logits tensor produced by the 'output' dense layer.\n",
    "    \"\"\"\n",
    "    print(\"Training mode: \", training==True)\n",
    "    # Two conv stages (50 then 100 filters); layers are created in the same order as\n",
    "    # before, so the auto-generated variable names are unchanged\n",
    "    block1 = _conv_block(X['features'], filters=50, training=training)\n",
    "    block2 = _conv_block(block1, filters=100, training=training)\n",
    "\n",
    "    # Derive the per-example feature count from the static shape instead of hard-coding\n",
    "    # 100*8*8: a wrong constant would silently fold extra rows into the batch dimension.\n",
    "    # This needs the non-batch dimensions to be statically known.\n",
    "    flat_dim = int(np.prod(block2.get_shape().as_list()[1:]))\n",
    "    flatten = tf.reshape(block2, shape=[-1, flat_dim])\n",
    "    fc1 = tf.layers.dense(flatten, 512, activation=tf.nn.relu)\n",
    "    drop3 = tf.layers.dropout(fc1, 0.5, training=training)\n",
    "    logits = tf.layers.dense(drop3, n_classes, name='output')\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode, params):\n",
    "    \"\"\"Estimator model function: build the network and the ops required for `mode`.\n",
    "\n",
    "    Args:\n",
    "        features: dict of input tensors, passed to create_symbol as X.\n",
    "        labels: integer class ids used for the loss and accuracy; not used in PREDICT mode.\n",
    "        mode: one of tf.estimator.ModeKeys (TRAIN, EVAL or PREDICT).\n",
    "        params: dict providing 'n_classes', 'lr' and 'momentum'.\n",
    "\n",
    "    Returns:\n",
    "        A tf.estimator.EstimatorSpec populated for the requested mode.\n",
    "    \"\"\"\n",
    "    # Create symbol (dropout is only enabled in TRAIN mode)\n",
    "    logits = create_symbol(X=features,\n",
    "                           training=(mode == tf.estimator.ModeKeys.TRAIN),\n",
    "                           n_classes=params['n_classes'])\n",
    "    # Predictions: index of the largest logit along the class axis\n",
    "    predictions = tf.argmax(logits, axis=1)\n",
    "    # ModeKeys.PREDICT\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        return tf.estimator.EstimatorSpec(mode=mode,\n",
    "                                          predictions={\"output\": predictions})\n",
    "    # Loss (needed by both EVAL and TRAIN)\n",
    "    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)\n",
    "    loss = tf.reduce_mean(xentropy)\n",
    "    # ModeKeys.EVAL: only the loss and metrics are needed, so no optimizer is built\n",
    "    if mode == tf.estimator.ModeKeys.EVAL:\n",
    "        eval_metric_ops = {\"acc\": tf.metrics.accuracy(labels=labels, predictions=predictions)}\n",
    "        return tf.estimator.EstimatorSpec(mode=mode,\n",
    "                                          loss=loss,\n",
    "                                          eval_metric_ops=eval_metric_ops)\n",
    "    # ModeKeys.TRAIN: optimizer and train op\n",
    "    optimizer = tf.train.MomentumOptimizer(learning_rate=params['lr'],\n",
    "                                           momentum=params['momentum'])\n",
    "    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
    "    return tf.estimator.EstimatorSpec(mode=mode,\n",
    "                                      loss=loss,\n",
    "                                      train_op=train_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing train set...\n",
      "Preparing test set...\n",
      "(50000, 3, 32, 32) (10000, 3, 32, 32) (50000,) (10000,)\n",
      "float32 float32 int32 int32\n",
      "CPU times: user 643 ms, sys: 581 ms, total: 1.22 s\n",
      "Wall time: 1.22 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load the train/test arrays in channels-first layout via cifar_for_library\n",
    "x_train, x_test, y_train, y_test = cifar_for_library(channel_first=True)\n",
    "loaded_arrays = (x_train, x_test, y_train, y_test)\n",
    "# Quick sanity check: shapes on one line, dtypes on the next\n",
    "print(*(arr.shape for arr in loaded_arrays))\n",
    "print(*(arr.dtype for arr in loaded_arrays))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpq9y0ghxl\n",
      "INFO:tensorflow:Using config: {'_save_checkpoints_steps': None, '_num_worker_replicas': 1, '_task_type': 'worker', '_keep_checkpoint_every_n_hours': 10000, '_save_checkpoints_secs': 600, '_log_step_count_steps': 1000, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f814d269cf8>, '_is_chief': True, '_num_ps_replicas': 0, '_tf_random_seed': None, '_master': '', '_keep_checkpoint_max': 5, '_task_id': 0, '_save_summary_steps': 1000, '_model_dir': '/tmp/tmpq9y0ghxl', '_session_config': None, '_service': None}\n",
      "CPU times: user 0 ns, sys: 3.72 ms, total: 3.72 ms\n",
      "Wall time: 3.18 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Hyper-parameters handed to model_fn through its `params` argument\n",
    "hyperparams = {\"lr\": LR,\n",
    "               \"momentum\": MOMENTUM,\n",
    "               \"n_classes\": N_CLASSES}\n",
    "# Report global_step/sec and save summaries every 1000 steps\n",
    "run_config = tf.estimator.RunConfig(log_step_count_steps=1000,\n",
    "                                    save_summary_steps=1000)\n",
    "# Create Estimator (no model_dir is given, so a temporary folder is used)\n",
    "nn = tf.estimator.Estimator(model_fn=model_fn,\n",
    "                            params=hyperparams,\n",
    "                            config=run_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training mode:  True\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpq9y0ghxl/model.ckpt.\n",
      "INFO:tensorflow:loss = 2.3170414, step = 1\n",
      "INFO:tensorflow:loss = 2.2064767, step = 101 (0.897 sec)\n",
      "INFO:tensorflow:loss = 2.1269495, step = 201 (0.740 sec)\n",
      "INFO:tensorflow:loss = 1.786201, step = 301 (0.759 sec)\n",
      "INFO:tensorflow:loss = 1.8294327, step = 401 (0.734 sec)\n",
      "INFO:tensorflow:loss = 2.0215316, step = 501 (0.727 sec)\n",
      "INFO:tensorflow:loss = 1.8027611, step = 601 (0.729 sec)\n",
      "INFO:tensorflow:loss = 1.655209, step = 701 (0.727 sec)\n",
      "INFO:tensorflow:loss = 1.5047133, step = 801 (0.779 sec)\n",
      "INFO:tensorflow:loss = 1.364079, step = 901 (0.759 sec)\n",
      "INFO:tensorflow:global_step/sec: 131.782\n",
      "INFO:tensorflow:loss = 1.8369548, step = 1001 (0.737 sec)\n",
      "INFO:tensorflow:loss = 1.3392849, step = 1101 (0.744 sec)\n",
      "INFO:tensorflow:loss = 1.44081, step = 1201 (0.732 sec)\n",
      "INFO:tensorflow:loss = 1.1232271, step = 1301 (0.737 sec)\n",
      "INFO:tensorflow:loss = 1.246576, step = 1401 (0.735 sec)\n",
      "INFO:tensorflow:loss = 1.399048, step = 1501 (0.728 sec)\n",
      "INFO:tensorflow:loss = 1.3109409, step = 1601 (0.750 sec)\n",
      "INFO:tensorflow:loss = 1.0977027, step = 1701 (0.765 sec)\n",
      "INFO:tensorflow:loss = 1.2988272, step = 1801 (0.753 sec)\n",
      "INFO:tensorflow:loss = 1.2889335, step = 1901 (0.740 sec)\n",
      "INFO:tensorflow:global_step/sec: 134.623\n",
      "INFO:tensorflow:loss = 1.2646211, step = 2001 (0.743 sec)\n",
      "INFO:tensorflow:loss = 1.1483378, step = 2101 (0.735 sec)\n",
      "INFO:tensorflow:loss = 1.0285265, step = 2201 (0.779 sec)\n",
      "INFO:tensorflow:loss = 0.8870604, step = 2301 (0.733 sec)\n",
      "INFO:tensorflow:loss = 0.92534465, step = 2401 (0.739 sec)\n",
      "INFO:tensorflow:loss = 1.0498724, step = 2501 (0.781 sec)\n",
      "INFO:tensorflow:loss = 1.0085716, step = 2601 (0.741 sec)\n",
      "INFO:tensorflow:loss = 0.82484204, step = 2701 (0.732 sec)\n",
      "INFO:tensorflow:loss = 1.1081153, step = 2801 (0.740 sec)\n",
      "INFO:tensorflow:loss = 0.82805943, step = 2901 (0.725 sec)\n",
      "INFO:tensorflow:global_step/sec: 134.346\n",
      "INFO:tensorflow:loss = 0.9152268, step = 3001 (0.739 sec)\n",
      "INFO:tensorflow:loss = 0.8082077, step = 3101 (0.743 sec)\n",
      "INFO:tensorflow:loss = 0.7954086, step = 3201 (0.738 sec)\n",
      "INFO:tensorflow:loss = 0.9318168, step = 3301 (0.770 sec)\n",
      "INFO:tensorflow:loss = 0.7816889, step = 3401 (0.760 sec)\n",
      "INFO:tensorflow:loss = 0.7232264, step = 3501 (0.731 sec)\n",
      "INFO:tensorflow:loss = 0.7814284, step = 3601 (0.740 sec)\n",
      "INFO:tensorflow:loss = 0.7805284, step = 3701 (0.732 sec)\n",
      "INFO:tensorflow:loss = 0.83298635, step = 3801 (0.734 sec)\n",
      "INFO:tensorflow:loss = 0.7007756, step = 3901 (0.759 sec)\n",
      "INFO:tensorflow:global_step/sec: 134.317\n",
      "INFO:tensorflow:loss = 0.9259589, step = 4001 (0.738 sec)\n",
      "INFO:tensorflow:loss = 1.026039, step = 4101 (0.747 sec)\n",
      "INFO:tensorflow:loss = 0.7151191, step = 4201 (0.789 sec)\n",
      "INFO:tensorflow:loss = 0.63691247, step = 4301 (0.754 sec)\n",
      "INFO:tensorflow:loss = 1.104894, step = 4401 (0.730 sec)\n",
      "INFO:tensorflow:loss = 0.6340748, step = 4501 (0.736 sec)\n",
      "INFO:tensorflow:loss = 0.8548136, step = 4601 (0.736 sec)\n",
      "INFO:tensorflow:loss = 0.7710654, step = 4701 (0.734 sec)\n",
      "INFO:tensorflow:loss = 0.6889061, step = 4801 (0.745 sec)\n",
      "INFO:tensorflow:loss = 0.7177516, step = 4901 (0.738 sec)\n",
      "INFO:tensorflow:global_step/sec: 134.093\n",
      "INFO:tensorflow:loss = 0.7304247, step = 5001 (0.750 sec)\n",
      "INFO:tensorflow:loss = 0.6389053, step = 5101 (0.791 sec)\n",
      "INFO:tensorflow:loss = 0.63810164, step = 5201 (0.743 sec)\n",
      "INFO:tensorflow:loss = 0.638374, step = 5301 (0.735 sec)\n",
      "INFO:tensorflow:loss = 0.9848243, step = 5401 (0.740 sec)\n",
      "INFO:tensorflow:loss = 0.9720929, step = 5501 (0.739 sec)\n",
      "INFO:tensorflow:loss = 0.7556629, step = 5601 (0.734 sec)\n",
      "INFO:tensorflow:loss = 0.74508137, step = 5701 (0.734 sec)\n",
      "INFO:tensorflow:loss = 0.65362096, step = 5801 (0.735 sec)\n",
      "INFO:tensorflow:loss = 0.6445347, step = 5901 (0.778 sec)\n",
      "INFO:tensorflow:global_step/sec: 133.275\n",
      "INFO:tensorflow:loss = 0.60195994, step = 6001 (0.773 sec)\n",
      "INFO:tensorflow:loss = 0.7363389, step = 6101 (0.742 sec)\n",
      "INFO:tensorflow:loss = 0.56986856, step = 6201 (0.729 sec)\n",
      "INFO:tensorflow:loss = 0.5275197, step = 6301 (0.736 sec)\n",
      "INFO:tensorflow:loss = 0.65698683, step = 6401 (0.737 sec)\n",
      "INFO:tensorflow:loss = 0.6537148, step = 6501 (0.734 sec)\n",
      "INFO:tensorflow:loss = 0.79957855, step = 6601 (0.741 sec)\n",
      "INFO:tensorflow:loss = 0.37656224, step = 6701 (0.745 sec)\n",
      "INFO:tensorflow:loss = 0.4103039, step = 6801 (0.784 sec)\n",
      "INFO:tensorflow:loss = 0.5148346, step = 6901 (0.757 sec)\n",
      "INFO:tensorflow:global_step/sec: 134.287\n",
      "INFO:tensorflow:loss = 0.77914476, step = 7001 (0.741 sec)\n",
      "INFO:tensorflow:loss = 0.6495944, step = 7101 (0.730 sec)\n",
      "INFO:tensorflow:loss = 0.5565534, step = 7201 (0.767 sec)\n",
      "INFO:tensorflow:loss = 0.51911247, step = 7301 (0.730 sec)\n",
      "INFO:tensorflow:loss = 0.46255705, step = 7401 (0.729 sec)\n",
      "INFO:tensorflow:loss = 0.3688233, step = 7501 (0.744 sec)\n",
      "INFO:tensorflow:loss = 0.62947345, step = 7601 (0.761 sec)\n",
      "INFO:tensorflow:loss = 0.42201072, step = 7701 (0.784 sec)\n",
      "INFO:tensorflow:loss = 0.55022347, step = 7801 (0.739 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 7813 into /tmp/tmpq9y0ghxl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.8183278.\n",
      "CPU times: user 56.4 s, sys: 14.2 s, total: 1min 10s\n",
      "Wall time: 1min\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.estimator.estimator.Estimator at 0x7f814d269978>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "# Train Estimator: 60s\n",
    "# Feed the in-memory training arrays as shuffled mini-batches for 10 epochs\n",
    "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "    x={\"features\": x_train},\n",
    "    y=y_train,\n",
    "    batch_size=BATCHSIZE,\n",
    "    num_epochs=10,\n",
    "    shuffle=True)\n",
    "nn.train(input_fn=train_input_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training mode:  False\n",
      "INFO:tensorflow:Starting evaluation at 2018-03-10-02:01:29\n",
      "INFO:tensorflow:Restoring parameters from /tmp/tmpq9y0ghxl/model.ckpt-7813\n",
      "INFO:tensorflow:Finished evaluation at 2018-03-10-02:01:30\n",
      "INFO:tensorflow:Saving dict for global step 7813: acc = 0.7815, global_step = 7813, loss = 0.6418523\n",
      "CPU times: user 965 ms, sys: 129 ms, total: 1.09 s\n",
      "Wall time: 1.03 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'acc': 0.7815, 'global_step': 7813, 'loss': 0.6418523}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "# Evaluate estimator: 1s\n",
    "# Accuracy: 78%\n",
    "# Feed the test arrays in their original order (no shuffling)\n",
    "test_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "    x={\"features\": x_test},\n",
    "    y=y_test,\n",
    "    shuffle=False)\n",
    "nn.evaluate(input_fn=test_input_fn)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
